{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "centernet_on_mobile.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pDIqEDDxWAh2"
      },
      "source": [
        "#Introduction\r\n",
        "\r\n",
        "Welcome to the **CenterNet on-device with TensorFlow Lite** Colab. Here, we demonstrate how you can run a mobile-optimized version of the [CenterNet](https://arxiv.org/abs/1904.08189) architecture with [TensorFlow Lite](https://www.tensorflow.org/lite) (a.k.a.  TFLite). \r\n",
        "\r\n",
        "Users can use this notebook as a reference for obtaining TFLite version of CenterNet for *Object Detection* or [*Keypoint detection*](https://cocodataset.org/#keypoints-2020). The code also shows how to perform pre-/post-processing & inference with TFLite's Python API.\r\n",
        "\r\n",
        "**NOTE:** CenterNet support in TFLite is still experimental, and currently works with floating-point inference only."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3LQWTJ-BWzmW"
      },
      "source": [
        "# Set Up"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gx84EpH7INPj"
      },
      "source": [
        "## Libraries & Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EU_hXi7IW9QC"
      },
      "source": [
        "!pip install tf-nightly"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZTU9_JcOZz-J"
      },
      "source": [
        "import os\r\n",
        "import pathlib\r\n",
        "\r\n",
        "# Clone the tensorflow models repository if it doesn't already exist\r\n",
        "if \"models\" in pathlib.Path.cwd().parts:\r\n",
        "  while \"models\" in pathlib.Path.cwd().parts:\r\n",
        "    os.chdir('..')\r\n",
        "elif not pathlib.Path('models').exists():\r\n",
        "  !git clone --depth 1 https://github.com/tensorflow/models"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "laJxis1WZ2xj"
      },
      "source": [
        "# Install the Object Detection API\r\n",
        "%%bash\r\n",
        "cd models/research/\r\n",
        "protoc object_detection/protos/*.proto --python_out=.\r\n",
        "cp object_detection/packages/tf2/setup.py .\r\n",
        "python -m pip install ."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "je0LrJNjDsk9"
      },
      "source": [
        "import matplotlib\r\n",
        "import matplotlib.pyplot as plt\r\n",
        "\r\n",
        "import os\r\n",
        "import random\r\n",
        "import io\r\n",
        "import imageio\r\n",
        "import glob\r\n",
        "import scipy.misc\r\n",
        "import numpy as np\r\n",
        "from six import BytesIO\r\n",
        "from PIL import Image, ImageDraw, ImageFont\r\n",
        "from IPython.display import display, Javascript\r\n",
        "from IPython.display import Image as IPyImage\r\n",
        "\r\n",
        "import tensorflow as tf\r\n",
        "\r\n",
        "from object_detection.utils import label_map_util\r\n",
        "from object_detection.utils import config_util\r\n",
        "from object_detection.utils import visualization_utils as viz_utils\r\n",
        "from object_detection.utils import colab_utils\r\n",
        "from object_detection.utils import config_util\r\n",
        "from object_detection.builders import model_builder\r\n",
        "\r\n",
        "%matplotlib inline"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O5IXwbhhH0bs"
      },
      "source": [
        "## Test Image from COCO\r\n",
        "\r\n",
        "We use a sample image from the COCO'17 validation dataset that contains people, to showcase inference with CenterNet."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h-JuG84HDvm-"
      },
      "source": [
        "# Download COCO'17 validation set for test image\r\n",
        "%%bash\r\n",
        "mkdir -p coco && cd coco\r\n",
        "wget -q -N http://images.cocodataset.org/zips/val2017.zip\r\n",
        "unzip -q -o val2017.zip && rm *.zip\r\n",
        "cd .."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "peX5mPGmEj8s"
      },
      "source": [
        "# Print the image we are going to test on as a sanity check.\r\n",
        "\r\n",
        "def load_image_into_numpy_array(path):\r\n",
        "  \"\"\"Load an image from file into a numpy array.\r\n",
        "\r\n",
        "  Puts image into numpy array to feed into tensorflow graph.\r\n",
        "  Note that by convention we put it into a numpy array with shape\r\n",
        "  (height, width, channels), where channels=3 for RGB.\r\n",
        "\r\n",
        "  Args:\r\n",
        "    path: a file path.\r\n",
        "\r\n",
        "  Returns:\r\n",
        "    uint8 numpy array with shape (img_height, img_width, 3)\r\n",
        "  \"\"\"\r\n",
        "  img_data = tf.io.gfile.GFile(path, 'rb').read()\r\n",
        "  image = Image.open(BytesIO(img_data))\r\n",
        "  (im_width, im_height) = image.size\r\n",
        "  return np.array(image.getdata()).reshape(\r\n",
        "      (im_height, im_width, 3)).astype(np.uint8)\r\n",
        "\r\n",
        "image_path = 'coco/val2017/000000013729.jpg'\r\n",
        "plt.figure(figsize = (30, 20))\r\n",
        "plt.imshow(load_image_into_numpy_array(image_path))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6cqOdvfrR1vW"
      },
      "source": [
        "## Utilities for Inference\r\n",
        "\r\n",
        "The `detect` function shown below describes how input and output tensors from CenterNet (obtained in subsequent sections) can be processed. This logic can be ported to other languages depending on your application (for e.g. to Java for Android apps)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "13Ouc2C3RyqR"
      },
      "source": [
        "def detect(interpreter, input_tensor, include_keypoint=False):\r\n",
        "  \"\"\"Run detection on an input image.\r\n",
        "\r\n",
        "  Args:\r\n",
        "    interpreter: tf.lite.Interpreter\r\n",
        "    input_tensor: A [1, height, width, 3] Tensor of type tf.float32.\r\n",
        "      Note that height and width can be anything since the image will be\r\n",
        "      immediately resized according to the needs of the model within this\r\n",
        "      function.\r\n",
        "    include_keypoint: True if model supports keypoints output. See\r\n",
        "      https://cocodataset.org/#keypoints-2020\r\n",
        "\r\n",
        "  Returns:\r\n",
        "    A sequence containing the following output tensors:\r\n",
        "      boxes: a numpy array of shape [N, 4]\r\n",
        "      classes: a numpy array of shape [N]. Note that class indices are \r\n",
        "        1-based, and match the keys in the label map.\r\n",
        "      scores: a numpy array of shape [N] or None.  If scores=None, then\r\n",
        "        this function assumes that the boxes to be plotted are groundtruth\r\n",
        "        boxes and plot all boxes as black with no classes or scores.\r\n",
        "      category_index: a dict containing category dictionaries (each holding\r\n",
        "        category index `id` and category name `name`) keyed by category \r\n",
        "        indices.\r\n",
        "    If include_keypoints is True, the following are also returned:\r\n",
        "      keypoints: (optional) a numpy array of shape [N, 17, 2] representing\r\n",
        "        the yx-coordinates of the detection 17 COCO human keypoints\r\n",
        "        (https://cocodataset.org/#keypoints-2020) in normalized image frame\r\n",
        "        (i.e. [0.0, 1.0]). \r\n",
        "      keypoint_scores: (optional) a numpy array of shape [N, 17] representing the\r\n",
        "        keypoint prediction confidence scores.\r\n",
        "  \"\"\"\r\n",
        "  input_details = interpreter.get_input_details()\r\n",
        "  output_details = interpreter.get_output_details()\r\n",
        "\r\n",
        "  interpreter.set_tensor(input_details[0]['index'], input_tensor.numpy())\r\n",
        "\r\n",
        "  interpreter.invoke()\r\n",
        "\r\n",
        "  scores = interpreter.get_tensor(output_details[0]['index'])\r\n",
        "  boxes = interpreter.get_tensor(output_details[1]['index'])\r\n",
        "  num_detections = interpreter.get_tensor(output_details[2]['index'])\r\n",
        "  classes = interpreter.get_tensor(output_details[3]['index'])\r\n",
        "\r\n",
        "  if include_keypoint:\r\n",
        "    kpts = interpreter.get_tensor(output_details[4]['index'])\r\n",
        "    kpts_scores = interpreter.get_tensor(output_details[5]['index'])\r\n",
        "    return boxes, classes, scores, num_detections, kpts, kpts_scores\r\n",
        "  else:\r\n",
        "    return boxes, classes, scores, num_detections\r\n",
        "\r\n",
        "# Utility for visualizing results\r\n",
        "def plot_detections(image_np,\r\n",
        "                    boxes,\r\n",
        "                    classes,\r\n",
        "                    scores,\r\n",
        "                    category_index,\r\n",
        "                    keypoints=None,\r\n",
        "                    keypoint_scores=None,\r\n",
        "                    figsize=(12, 16),\r\n",
        "                    image_name=None):\r\n",
        "  \"\"\"Wrapper function to visualize detections.\r\n",
        "\r\n",
        "  Args:\r\n",
        "    image_np: uint8 numpy array with shape (img_height, img_width, 3)\r\n",
        "    boxes: a numpy array of shape [N, 4]\r\n",
        "    classes: a numpy array of shape [N]. Note that class indices are 1-based,\r\n",
        "      and match the keys in the label map.\r\n",
        "    scores: a numpy array of shape [N] or None.  If scores=None, then\r\n",
        "      this function assumes that the boxes to be plotted are groundtruth\r\n",
        "      boxes and plot all boxes as black with no classes or scores.\r\n",
        "    category_index: a dict containing category dictionaries (each holding\r\n",
        "      category index `id` and category name `name`) keyed by category indices.\r\n",
        "    keypoints: (optional) a numpy array of shape [N, 17, 2] representing the \r\n",
        "      yx-coordinates of the detection 17 COCO human keypoints\r\n",
        "      (https://cocodataset.org/#keypoints-2020) in normalized image frame\r\n",
        "      (i.e. [0.0, 1.0]). \r\n",
        "    keypoint_scores: (optional) anumpy array of shape [N, 17] representing the\r\n",
        "      keypoint prediction confidence scores.\r\n",
        "    figsize: size for the figure.\r\n",
        "    image_name: a name for the image file.\r\n",
        "  \"\"\"\r\n",
        "\r\n",
        "  keypoint_edges = [(0, 1),\r\n",
        "        (0, 2),\r\n",
        "        (1, 3),\r\n",
        "        (2, 4),\r\n",
        "        (0, 5),\r\n",
        "        (0, 6),\r\n",
        "        (5, 7),\r\n",
        "        (7, 9),\r\n",
        "        (6, 8),\r\n",
        "        (8, 10),\r\n",
        "        (5, 6),\r\n",
        "        (5, 11),\r\n",
        "        (6, 12),\r\n",
        "        (11, 12),\r\n",
        "        (11, 13),\r\n",
        "        (13, 15),\r\n",
        "        (12, 14),\r\n",
        "        (14, 16)]\r\n",
        "  image_np_with_annotations = image_np.copy()\r\n",
        "  # Only visualize objects that get a score > 0.3.\r\n",
        "  viz_utils.visualize_boxes_and_labels_on_image_array(\r\n",
        "      image_np_with_annotations,\r\n",
        "      boxes,\r\n",
        "      classes,\r\n",
        "      scores,\r\n",
        "      category_index,\r\n",
        "      keypoints=keypoints,\r\n",
        "      keypoint_scores=keypoint_scores,\r\n",
        "      keypoint_edges=keypoint_edges,\r\n",
        "      use_normalized_coordinates=True,\r\n",
        "      min_score_thresh=0.3)\r\n",
        "  if image_name:\r\n",
        "    plt.imsave(image_name, image_np_with_annotations)\r\n",
        "  else:\r\n",
        "    return image_np_with_annotations"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3cNYi8HuIWzO"
      },
      "source": [
        "# Object Detection"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "azdDCdWQMSoH"
      },
      "source": [
        "## Download Model from Detection Zoo\r\n",
        "\r\n",
        "**NOTE:** Not all CenterNet models from the [TF2 Detection Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md) work with TFLite, only the [MobileNet-based version](http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz) does.\r\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Sywt8MKzIeOi"
      },
      "source": [
        "# Get mobile-friendly CenterNet for Object Detection\r\n",
        "# See TensorFlow 2 Detection Model Zoo for more details:\r\n",
        "# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md\r\n",
        "\r\n",
        "%%bash\r\n",
        "wget http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz\r\n",
        "tar -xf centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz\r\n",
        "rm centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz*"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MiRrVpTnLvsk"
      },
      "source": [
        "Now that we have downloaded the CenterNet model that uses MobileNet as a backbone, we can obtain a TensorFlow Lite model from it. \r\n",
        "\r\n",
        "The downloaded archive already contains `model.tflite` that works with TensorFlow Lite, but we re-generate the model in the next sub-section to account for cases where you might re-train the model on your own dataset (with corresponding changes to `pipeline.config` & `checkpoint` directory)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jT0bruuxM496"
      },
      "source": [
        "## Generate TensorFlow Lite Model\r\n",
        "\r\n",
        "First, we invoke `export_tflite_graph_tf2.py` to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TensorFlow Lite Converter for generating the final model.\r\n",
        "\r\n",
        "This is similar to what we do for [SSD architectures](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jpcCjiQ_JrU5",
        "collapsed": true
      },
      "source": [
        "%%bash\r\n",
        "# Export the intermediate SavedModel that outputs 10 detections & takes in an \r\n",
        "# image of dim 320x320.\r\n",
        "# Modify these parameters according to your needs.\r\n",
        "\r\n",
        "python models/research/object_detection/export_tflite_graph_tf2.py \\\r\n",
        "  --pipeline_config_path=centernet_mobilenetv2_fpn_od/pipeline.config \\\r\n",
        "  --trained_checkpoint_dir=centernet_mobilenetv2_fpn_od/checkpoint \\\r\n",
        "  --output_directory=centernet_mobilenetv2_fpn_od/tflite \\\r\n",
        "  --centernet_include_keypoints=false \\\r\n",
        "  --max_detections=10 \\\r\n",
        "  --config_override=\" \\\r\n",
        "    model{ \\\r\n",
        "      center_net { \\\r\n",
        "        image_resizer { \\\r\n",
        "          fixed_shape_resizer { \\\r\n",
        "            height: 320 \\\r\n",
        "            width: 320 \\\r\n",
        "          } \\\r\n",
        "        } \\\r\n",
        "      } \\\r\n",
        "    }\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zhhP6HL8PUGq"
      },
      "source": [
        "# Generate TensorFlow Lite model using the converter.\r\n",
        "%%bash\r\n",
        "tflite_convert --output_file=centernet_mobilenetv2_fpn_od/model.tflite \\\r\n",
        "  --saved_model_dir=centernet_mobilenetv2_fpn_od/tflite/saved_model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gj1Q_e_2Rn5i"
      },
      "source": [
        "## TensorFlow Lite Inference\r\n",
        "\r\n",
        "Use a TensorFlow Lite Interpreter to detect objects in the test image."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uV9t9icURsei"
      },
      "source": [
        "%matplotlib inline\r\n",
        "\r\n",
        "# Load the TFLite model and allocate tensors.\r\n",
        "model_path = 'centernet_mobilenetv2_fpn_od/model.tflite'\r\n",
        "label_map_path = 'centernet_mobilenetv2_fpn_od/label_map.txt'\r\n",
        "image_path = 'coco/val2017/000000013729.jpg'\r\n",
        "\r\n",
        "# Initialize TensorFlow Lite Interpreter.\r\n",
        "interpreter = tf.lite.Interpreter(model_path=model_path)\r\n",
        "interpreter.allocate_tensors()\r\n",
        "\r\n",
        "# Label map can be used to figure out what class ID maps to what\r\n",
        "# label. `label_map.txt` is human-readable.\r\n",
        "category_index = label_map_util.create_category_index_from_labelmap(\r\n",
        "    label_map_path)\r\n",
        "\r\n",
        "label_id_offset = 1\r\n",
        "\r\n",
        "image = tf.io.read_file(image_path)\r\n",
        "image = tf.compat.v1.image.decode_jpeg(image)\r\n",
        "image = tf.expand_dims(image, axis=0)\r\n",
        "image_numpy = image.numpy()\r\n",
        "\r\n",
        "input_tensor = tf.convert_to_tensor(image_numpy, dtype=tf.float32)\r\n",
        "# Note that CenterNet doesn't require any pre-processing except resizing to the\r\n",
        "# input size that the TensorFlow Lite Interpreter was generated with.\r\n",
        "input_tensor = tf.image.resize(input_tensor, (320, 320))\r\n",
        "boxes, classes, scores, num_detections = detect(interpreter, input_tensor)\r\n",
        "\r\n",
        "vis_image = plot_detections(\r\n",
        "    image_numpy[0],\r\n",
        "    boxes[0],\r\n",
        "    classes[0].astype(np.uint32) + label_id_offset,\r\n",
        "    scores[0],\r\n",
        "    category_index)\r\n",
        "plt.figure(figsize = (30, 20))\r\n",
        "plt.imshow(vis_image)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DefXu4JXVxPD"
      },
      "source": [
        "# Keypoints\r\n",
        "\r\n",
        "Unlike SSDs, CenterNet also supports COCO [Keypoint detection](https://cocodataset.org/#keypoints-2020). To be more specific, the 'keypoints' version of CenterNet shown here provides keypoints as a `[N, 17, 2]`-shaped tensor representing the (normalized) yx-coordinates of 17 COCO human keypoints.\r\n",
        "\r\n",
        "See the `detect()` function in the **Utilities for Inference** section to better understand the keypoints output."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xu47DkrDV18O"
      },
      "source": [
        "## Download Model from Detection Zoo\r\n",
        "\r\n",
        "**NOTE:** Not all CenterNet models from the [TF2 Detection Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md) work with TFLite, only the [MobileNet-based version](http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz) does."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sd7f64WjWD7z"
      },
      "source": [
        "# Get mobile-friendly CenterNet for Keypoint detection task.\r\n",
        "# See TensorFlow 2 Detection Model Zoo for more details:\r\n",
        "# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md\r\n",
        "\r\n",
        "%%bash\r\n",
        "wget http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_kpts.tar.gz\r\n",
        "tar -xf centernet_mobilenetv2fpn_512x512_coco17_kpts.tar.gz\r\n",
        "rm centernet_mobilenetv2fpn_512x512_coco17_kpts.tar.gz*"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NSFc-xSLX1ZC"
      },
      "source": [
        "## Generate TensorFlow Lite Model\r\n",
        "\r\n",
        "As before, we leverage `export_tflite_graph_tf2.py` to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TFLite converter to generating the final model.\r\n",
        "\r\n",
        "Note that we need to include an additional `keypoint_label_map_path` parameter for exporting the keypoints outputs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8kEhwYynX-cD"
      },
      "source": [
        "%%bash\r\n",
        "# Export the intermediate SavedModel that outputs 10 detections & takes in an \r\n",
        "# image of dim 320x320.\r\n",
        "# Modify these parameters according to your needs.\r\n",
        "\r\n",
        "python models/research/object_detection/export_tflite_graph_tf2.py \\\r\n",
        "  --pipeline_config_path=centernet_mobilenetv2_fpn_kpts/pipeline.config \\\r\n",
        "  --trained_checkpoint_dir=centernet_mobilenetv2_fpn_kpts/checkpoint \\\r\n",
        "  --output_directory=centernet_mobilenetv2_fpn_kpts/tflite \\\r\n",
        "  --centernet_include_keypoints=true \\\r\n",
        "  --keypoint_label_map_path=centernet_mobilenetv2_fpn_kpts/label_map.txt \\\r\n",
        "  --max_detections=10 \\\r\n",
        "  --config_override=\" \\\r\n",
        "    model{ \\\r\n",
        "      center_net { \\\r\n",
        "        image_resizer { \\\r\n",
        "          fixed_shape_resizer { \\\r\n",
        "            height: 320 \\\r\n",
        "            width: 320 \\\r\n",
        "          } \\\r\n",
        "        } \\\r\n",
        "      } \\\r\n",
        "    }\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TJtsyMlLY1DU"
      },
      "source": [
        "# Generate TensorFlow Lite model using the converter.\r\n",
        "\r\n",
        "%%bash\r\n",
        "tflite_convert --output_file=centernet_mobilenetv2_fpn_kpts/model.tflite \\\r\n",
        "  --saved_model_dir=centernet_mobilenetv2_fpn_kpts/tflite/saved_model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nJCxPBjYZSk6"
      },
      "source": [
        "## TensorFlow Lite Inference\r\n",
        "\r\n",
        "Use a TensorFlow Lite Interpreter to detect people & their keypoints in the test image."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F2GpD7r8ZUzx"
      },
      "source": [
        "%matplotlib inline\r\n",
        "\r\n",
        "# Load the TFLite model and allocate tensors.\r\n",
        "model_path = 'centernet_mobilenetv2_fpn_kpts/model.tflite'\r\n",
        "image_path = 'coco/val2017/000000013729.jpg'\r\n",
        "\r\n",
        "# Initialize TensorFlow Lite Interpreter.\r\n",
        "interpreter = tf.lite.Interpreter(model_path=model_path)\r\n",
        "interpreter.allocate_tensors()\r\n",
        "\r\n",
        "# Keypoints are only relevant for people, so we only care about that\r\n",
        "# category Id here.\r\n",
        "category_index = {1: {'id': 1, 'name': 'person'}}\r\n",
        "\r\n",
        "label_id_offset = 1\r\n",
        "\r\n",
        "image = tf.io.read_file(image_path)\r\n",
        "image = tf.compat.v1.image.decode_jpeg(image)\r\n",
        "image = tf.expand_dims(image, axis=0)\r\n",
        "image_numpy = image.numpy()\r\n",
        "\r\n",
        "input_tensor = tf.convert_to_tensor(image_numpy, dtype=tf.float32)\r\n",
        "# Note that CenterNet doesn't require any pre-processing except resizing to\r\n",
        "# input size that the TensorFlow Lite Interpreter was generated with.\r\n",
        "input_tensor = tf.image.resize(input_tensor, (320, 320))\r\n",
        "(boxes, classes, scores, num_detections, kpts, kpts_scores) = detect(\r\n",
        "    interpreter, input_tensor, include_keypoint=True)\r\n",
        "\r\n",
        "vis_image = plot_detections(\r\n",
        "    image_numpy[0],\r\n",
        "    boxes[0],\r\n",
        "    classes[0].astype(np.uint32) + label_id_offset,\r\n",
        "    scores[0],\r\n",
        "    category_index,\r\n",
        "    keypoints=kpts[0],\r\n",
        "    keypoint_scores=kpts_scores[0])\r\n",
        "plt.figure(figsize = (30, 20))\r\n",
        "plt.imshow(vis_image)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "59Y3q6AC9C7s"
      },
      "source": [
        "# Running On Mobile\r\n",
        "\r\n",
        "As mentioned earlier, both the above models can be run on mobile phones with TensorFlow Lite. See our [**inference documentation**](https://www.tensorflow.org/lite/guide/inference) for general guidelines on platform-specific APIs & leveraging hardware acceleration. Both the object-detection & keypoint-detection versions of CenterNet are compatible with our [GPU delegate](https://www.tensorflow.org/lite/performance/gpu). *We are working on developing quantized versions of this model.*\r\n",
        "\r\n",
        "To leverage *object-detection* in your Android app, the simplest way is to use TFLite's [**ObjectDetector Task API**](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector). It is a high-level API that encapsulates complex but common image processing and post processing logic. Inference can be done in 5 lines of code. It is supported in Java for Android and C++ for native code. *We are working on building the Swift API for iOS, as well as the support for the keypoint-detection model.*\r\n",
        "\r\n",
        "To use the Task API, the model needs to be packed with [TFLite Metadata](https://www.tensorflow.org/lite/convert/metadata). This metadata helps the inference code perform the correct pre & post processing as required by the model. Use the following code to create the metadata."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8T_qzv6lDN_a"
      },
      "source": [
        "!pip install tflite_support_nightly"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CTZhBmfWDQ3z"
      },
      "source": [
        "from tflite_support.metadata_writers import object_detector\n",
        "from tflite_support.metadata_writers import writer_utils\n",
        "\n",
        "ObjectDetectorWriter = object_detector.MetadataWriter\n",
        "\n",
        "_MODEL_PATH = \"centernet_mobilenetv2_fpn_od/model.tflite\"\n",
        "_SAVE_TO_PATH = \"centernet_mobilenetv2_fpn_od/model_with_metadata.tflite\"\n",
        "_LABEL_PATH = \"centernet_mobilenetv2_fpn_od/tflite_label_map.txt\"\n",
        "\n",
        "# We need to convert Detection API's labelmap into what the Task API needs:\n",
        "# a txt file with one class name on each line from index 0 to N.\n",
        "# The first '0' class indicates the background.\n",
        "# This code assumes COCO detection which has 90 classes, you can write a label\n",
        "# map file for your model if re-trained.\n",
        "od_label_map_path = 'centernet_mobilenetv2_fpn_od/label_map.txt'\n",
        "category_index = label_map_util.create_category_index_from_labelmap(\n",
        "    label_map_path)\n",
        "f = open(_LABEL_PATH, 'w')\n",
        "for class_id in range(1, 91):\n",
        "  if class_id not in category_index:\n",
        "    f.write('???\\n')\n",
        "    continue\n",
        "  name = category_index[class_id]['name']\n",
        "  f.write(name+'\\n')\n",
        "f.close()\n",
        "\n",
        "writer = ObjectDetectorWriter.create_for_inference(\n",
        "    writer_utils.load_file(_MODEL_PATH), input_norm_mean=[0], \n",
        "    input_norm_std=[1], label_file_paths=[_LABEL_PATH])\n",
        "writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b2tc7awzDUHr"
      },
      "source": [
        "Visualize the metadata just created by the following code:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_SRqVdZNDYF1"
      },
      "source": [
        "from tflite_support import metadata\n",
        "\n",
        "displayer = metadata.MetadataDisplayer.with_model_file(_SAVE_TO_PATH)\n",
        "print(\"Metadata populated:\")\n",
        "print(displayer.get_metadata_json())\n",
        "print(\"=============================\")\n",
        "print(\"Associated file(s) populated:\")\n",
        "print(displayer.get_packed_associated_file_list())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SPUNsg9eDjWT"
      },
      "source": [
        "See more information about *object-detection* models from our [public documentation](https://www.tensorflow.org/lite/examples/object_detection/overview). The [Object Detection example app](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection) is a good starting point for integrating that model into your Android and iOS app. You can find [examples](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android#switch-between-inference-solutions-task-library-vs-tflite-interpreter) of using both the TFLite Task Library and TFLite Interpreter API."
      ]
    }
  ]
}
